# train_hhmm.py  (deterministic)
import os
# ---- set thread env BEFORE importing numpy/sklearn/torch ----
# Each variable defaults to "1"; setdefault leaves any value that is already
# present in the environment untouched, so a caller can still override it.
# NOTE(review): presumably single-threaded math kernels are wanted so that
# results are reproducible run-to-run (see the "(deterministic)" header) --
# confirm.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

import argparse
import random
import numpy as np

def _make_deterministic(seed: int) -> None:
    """Seed every random number generator this script may touch.

    Seeds Python's ``random`` module and NumPy's global RNG. When PyTorch is
    importable it also seeds torch (CPU and all CUDA devices), pins the cuDNN
    flags and requests deterministic algorithms in warn-only mode.

    Args:
        seed: Seed applied to all generators.

    Notes:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting
        it here cannot change string hashing in the running process; it is
        exported so that child processes inherit it.
    """
    import warnings

    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    try:
        import torch
    except ImportError:
        # torch is optional here; without it there is nothing left to set up.
        return

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Set the cuDNN flags before the call below, so that a failure there can
    # no longer leave them unconfigured.
    cudnn = getattr(torch.backends, "cudnn", None)
    if cudnn is not None:
        cudnn.benchmark = False
        cudnn.deterministic = True

    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except (AttributeError, TypeError) as exc:
        # The function or its ``warn_only`` keyword is missing in this torch
        # build. Stay best-effort, but make the gap visible instead of
        # swallowing it silently.
        warnings.warn(
            f"torch.use_deterministic_algorithms(warn_only=True) unavailable: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )

from hhmm_lib import (
    load_pt_records,
    build_top_sequences,
    fit_hhmm_fixed_top,
    coerce_labels_to_ids,
)
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

def main():
    """Train an always-anchored HHMM from a records file and save it as ``.npz``.

    Pipeline, in order:

    1. Parse CLI arguments and seed the RNGs via ``_make_deterministic``.
    2. Load records with ``load_pt_records``, sort them, and optionally keep
       only those whose ``is_correct`` flag matches ``--subset``.
    3. Build sequences with ``build_top_sequences``, sort them, and keep only
       those that contain ``--label_key``.
    4. Fit ``StandardScaler`` + ``PCA`` on all step features stacked together
       and replace each sequence's ``"steps"`` with the projected float64
       arrays.
    5. Fit the model with ``fit_hhmm_fixed_top`` and write its parameters plus
       the preprocessing statistics with ``np.savez``.

    Raises:
        RuntimeError: If no sequence contains the label key.
        ValueError: If the labelled sequences contain no step arrays at all.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in_pt", required=True)
    ap.add_argument("--C", type=int, default=4, help="#categories (top)")
    ap.add_argument("--K", type=int, default=7, help="#regimes per category (bottom)")
    ap.add_argument("--iters", type=int, default=10)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--out_npz", default="hhmm_model.npz")
    ap.add_argument("--label_key", default="sentences_with_labels",
                    help="Per-step labels field used to anchor top categories.")
    ap.add_argument("--pca_dim", type=int, default=64, help="PCA target dim (<= D).")
    ap.add_argument("--subset", choices=["all", "correct", "incorrect"], default="all",
                    help="Which data to train on")
    args = ap.parse_args()

    _make_deterministic(args.seed)

    def _none_last(value):
        """Wrap *value* so a missing field sorts last instead of raising.

        Comparing ``None`` with a real value raises ``TypeError`` in Python 3;
        the leading boolean decides the order before the values are compared.
        """
        return (value is None, value)

    # ---- load + filter records ----
    recs_all = load_pt_records(args.in_pt)

    # Stabilize order deterministically
    def _rec_key(r):
        prompt = r.get("prompt", "")
        return (
            _none_last(r.get("sample_id", None)),
            _none_last(r.get("i", None)),
            _none_last(r.get("id", None)),
            # An explicit ``prompt: None`` cannot be sliced; treat it as empty.
            "" if prompt is None else prompt[:32],
        )
    recs_all = sorted(recs_all, key=_rec_key)

    if args.subset == "all":
        recs = recs_all
    else:
        want_correct = (args.subset == "correct")
        # Records without an ``is_correct`` field count as incorrect.
        recs = [r for r in recs_all if bool(r.get("is_correct", False)) == want_correct]
    print(f"Loaded {len(recs)} records after subset='{args.subset}' filtering.")

    # ---- sequences (always anchored) ----
    seqs = build_top_sequences(recs)

    def _seq_key(s):
        return (
            _none_last(s.get("sample_id", None)),
            _none_last(s.get("i", None)),
            len(s.get("steps", [])),
        )
    seqs = sorted(seqs, key=_seq_key)

    before = len(seqs)
    seqs = [s for s in seqs if args.label_key in s]
    after = len(seqs)
    if after == 0:
        raise RuntimeError(
            f"No sequences contain '{args.label_key}'. "
            "Always-anchored training requires labels."
        )
    print(f"Anchored mode: kept {after}/{before} sequences with labels ({args.label_key}).")

    # sanity-check a few label payloads deterministically
    # NOTE(review): the result is discarded; presumably coerce_labels_to_ids
    # raises on malformed labels -- confirm in hhmm_lib.
    for s in seqs[:3]:
        _ = coerce_labels_to_ids(s.get(args.label_key))

    # ---- feature preprocessing (deterministic) ----
    step_arrays = [x for seq in seqs for x in seq["steps"]]
    if not step_arrays:
        # np.concatenate would raise an opaque ValueError on an empty list.
        raise ValueError("Labelled sequences contain no step arrays to train on.")
    X = np.concatenate(step_arrays, axis=0)  # [N*L, D]
    n_rows = X.shape[0]
    D_in = X.shape[1]
    scaler = StandardScaler(copy=True, with_mean=True, with_std=True).fit(X)
    X_scaled = scaler.transform(X)
    # The full SVD solver cannot produce more components than there are rows
    # or columns, so clamp by both.
    k = min(args.pca_dim, D_in, n_rows)
    if k < args.pca_dim:
        print(f"PCA dim clamped from {args.pca_dim} to {k} (D={D_in}, rows={n_rows}).")
    pca = PCA(n_components=k, svd_solver="full", random_state=args.seed).fit(X_scaled)

    # Project every step array with the fitted scaler + PCA.
    for seq in seqs:
        seq["steps"] = [
            pca.transform(scaler.transform(x)).astype(np.float64)
            for x in seq["steps"]
        ]

    # ---- Train HHMM (always passes label_key) ----
    print("Training HHMM ...")
    model = fit_hhmm_fixed_top(
        seqs,
        C=args.C,
        K=args.K,
        label_key=args.label_key,   # <- always anchored
        n_iter=args.iters,
        seed=args.seed,
        verbose=True,
    )

    # ---- Save model + preproc ----
    out = {
        "C": np.array([model.C], dtype=np.int32),
        "K": np.array([model.K], dtype=np.int32),
        "D": np.array([model.D], dtype=np.int32),
        "top_start": model.top.startprob,
        "top_trans": model.top.transmat,
        **{f"b{c}_start": model.bottom[c].startprob for c in range(model.C)},
        **{f"b{c}_trans": model.bottom[c].transmat  for c in range(model.C)},
        **{f"b{c}_means": model.bottom[c].means     for c in range(model.C)},
        **{f"b{c}_vars":  model.bottom[c].variances for c in range(model.C)},
        # preprocessing
        "prep_mean": scaler.mean_.astype(np.float64),
        "prep_scale": scaler.scale_.astype(np.float64),
        "prep_pca_components": pca.components_.astype(np.float64),
        "prep_pca_mean": pca.mean_.astype(np.float64),
        "prep_pca_explained_variance": pca.explained_variance_.astype(np.float64),
        "prep_pca_explained_variance_ratio": pca.explained_variance_ratio_.astype(np.float64),
        "prep_pca_singular_values": pca.singular_values_.astype(np.float64),
        # provenance
        "meta_subset": np.array(args.subset),
    }
    np.savez(args.out_npz, **out)
    # np.savez appends ".npz" when the name lacks it; report the real file.
    saved_path = args.out_npz if args.out_npz.endswith(".npz") else args.out_npz + ".npz"
    print(f"Saved HHMM model to {saved_path} (subset={args.subset})")
    print("Top startprob:", np.round(model.top.startprob, 3))
    print("Top transmat (row-stochastic):")
    with np.printoptions(precision=3, suppress=True):
        print(model.top.transmat)

# Run the training CLI only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()